'''
Created on Jul 17, 2012

@author: Xindong
'''

import os
from xml.sax.saxutils import escape

from shortcuts import execute_select, log_info


# Skeleton of a GEXF 1.2 (draft) document for a static, undirected graph.
# It declares the node attributes (name, lat, lon, level, dataset) and the edge
# attribute (level) that the <attvalue for="..."> entries of each node/edge
# refer to. The two %s slots receive the concatenated <node> and <edge>
# fragments, in that order.
str_body = """<?xml version="1.0" encoding="UTF-8"?>
<gexf xmlns="http://www.gexf.net/1.2draft" 
    xmlns:viz="http://www.gexf.net/1.2draft/viz" 
    xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" 
    xsi:schemaLocation="http://www.gexf.net/1.2draft 
                        http://www.gexf.net/1.2draft/gexf.xsd" 
    version="1.2">
    <graph mode="static" defaultedgetype="undirected">
        <attributes class="node" mode="static">
            <attribute id="name" title="name" type="string"/>
            <attribute id="lat" title="latitude" type="double"/>
            <attribute id="lon" title="longitude" type="double"/>
            <attribute id="level" title="level" type="integer"/>
            <attribute id="dataset" title="dataset" type="integer"/>
        </attributes>
        <attributes class="edge" mode="static">
            <attribute id="level" title="level" type="integer"/>
        </attributes>
        <nodes>%s
        </nodes>
        <edges>%s
        </edges>
    </graph>
</gexf>"""


# Drawing style per "level" value, used for both nodes and edges.
# color: level -> (R, G, B), each component 0-255.
color = {
    0: (0xD0, 0xD0, 0xD0),  # light grey
    1: (0xFF, 0xFF, 0x00),  # yellow
    2: (0xFF, 0x00, 0x00),  # red
    3: (0x00, 0xFF, 0x00),  # green
    4: (0xFF, 0xA5, 0x00),  # orange
    5: (0xFF, 0x00, 0xFF),  # magenta
}

# size: level -> viz:size value written for nodes of that level.
size = {
    0: 10,
    1: 10,
    2: 15,
    3: 20,
    4: 25,
    5: 30,
}


def node_gexf(node_id, latitude = 0, longitude = 0, level = 0, name = "Null"):
    """Serialize one node as a GEXF <node> XML fragment.

    node_id    node key, rendered with %s as both the XML id and the label;
               its first item (node_id[0]) is written as the "dataset" attribute.
    latitude, longitude
               coordinates, written with %f.
    level      key into the module-level ``color`` and ``size`` tables; picks
               the node's viz color/size and is written as the "level" attribute.
    name       value of the "name" attribute. ``None`` falls back to "Null";
               non-string values are converted with %s. The value is
               XML-escaped so it cannot break the document.

    Returns the fragment, starting with a newline so fragments can be
    concatenated directly into the <nodes> section.
    """
    node_color = 'r="%d" g="%d" b="%d" a="1"' % color[level]
    node_size = size[level]
    dataset_id = node_id[0]
    if name is None:
        name = "Null"
    elif not isinstance(name, (str, type(u""))):
        name = "%s" % (name,)
    # Names are free text (e.g. stop names); &, < and > must be escaped, and so
    # must the double quote because the value sits in a double-quoted attribute.
    safe_name = escape(name, {'"': "&quot;"})
    return """
            <node id="%s" label="%s">
                <attvalues>
                    <attvalue for="lat" value="%f"/>
                    <attvalue for="lon" value="%f"/>
                    <attvalue for="level" value="%d"/>
                    <attvalue for="dataset" value="%d"/>
                    <attvalue for="name" value="%s"/>
                </attvalues>
                <viz:size value="%d"/>
                <viz:color %s />
            </node>""" % (node_id, node_id, 
                          latitude, longitude, level, dataset_id, safe_name,  
                          node_size, 
                          node_color)


def edge_gexf(arc_id, dep_node, arr_node, level = 0):
    """Build the GEXF <edge> fragment for one arc.

    The edge id is arc_id itself; the visible label is the part of arc_id
    before the first "@" (the whole id when there is none). dep_node and
    arr_node are rendered with %s as source/target. level selects the RGB
    triple from ``color`` and is stored as the edge's "level" attribute.
    """
    red, green, blue = color[level]
    label = arc_id.partition("@")[0]
    template = """
            <edge id="%s" label="%s" source="%s" target="%s" weight="1">
                <attvalues>
                    <attvalue for="level" value="%d"/>
                </attvalues>
                <viz:color %s />
            </edge>"""
    rgba = 'r="%d" g="%d" b="%d" a="1"' % (red, green, blue)
    return template % (arc_id, label, dep_node, arr_node, level, rgba)


def graph_to_gexf(nodes, arcs, filename):
    ''' Render a graph as GEXF and write it to $HOME/output/<filename>.gexf.

        nodes
            key: node_id
            value: a dict with "latitude", "longitude" and "level"; an optional
                   "name" defaults to "Null" when absent
        arcs
            key: (dep, arr)
            value: a dict with "arc_id" and "level" (other keys are ignored)
        filename
            base name of the output file, without the ".gexf" extension

        The document is written as UTF-8 bytes, matching the encoding declared
        in the XML prolog of str_body.
    '''
    node_parts = [node_gexf(node_id,
                            node["latitude"], node["longitude"],
                            node["level"], node.get("name", "Null"))
                  for node_id, node in nodes.items()]
    edge_parts = [edge_gexf(arc["arc_id"], dep_node, arr_node, arc["level"])
                  for (dep_node, arr_node), arc in arcs.items()]

    # join() instead of repeated += keeps building the document linear.
    graph_out = str_body % ("".join(node_parts), "".join(edge_parts))
    if not isinstance(graph_out, bytes):
        graph_out = graph_out.encode("utf-8")

    # os.path.join rather than hard-coded "\\" separators: identical path on
    # Windows, and a real <HOME>/output/ path on other platforms.
    out_path = os.path.join(os.environ['HOME'], "output", "%s.gexf" % filename)
    with open(out_path, 'wb') as f:
        f.write(graph_out)
    log_info("Exported: " + filename)


def load_nodes(dataset_list):
    """Return {(dataset_id, node_id): node_info} for every dataset in dataset_list.

    Rows come from each ``<dataset>_graph_nodes`` table; latitude/longitude are
    converted with int() and divided by 1e5 (presumably fixed-point degrees —
    confirm against the schema). Each node is labelled with its name from
    load_stop_names() (KeyError if a node has none) and starts at level 0.
    """
    stop_names = load_stop_names(dataset_list)
    loaded = {}
    for dataset in dataset_list:
        query = """
            SELECT dataset_id, node_id, latitude, longitude
            FROM %s_graph_nodes
        """ % dataset
        for row in execute_select(query):
            key = (int(row[0]), int(row[1]))
            loaded[key] = {
                "latitude": int(row[2]) / 1e+5,
                "longitude": int(row[3]) / 1e+5,
                "level": 0,
                "name": stop_names[key],
            }
    return loaded


def load_stop_names(dataset_list):
    """Map (dataset_id, graph_node_id) -> stop_name over all ``<dataset>_stops`` tables.

    When the same key appears more than once, the row read last wins.
    """
    names_by_node = {}
    for dataset in dataset_list:
        rows = execute_select("""
            SELECT dataset_id, graph_node_id, stop_name
            FROM %s_stops
        """ % dataset)
        names_by_node.update(((int(r[0]), int(r[1])), r[2]) for r in rows)
    return names_by_node


def load_arcs(dataset_list):
    """Load the cheapest arc between each node pair of every dataset.

    Returns {(dep_node, arr_node): {"arc_id", "level", "cost"}} where nodes are
    (dataset_id, node_id) tuples, arc_id is "<dep_node>-<arr_node>", level is 0
    and cost is the minimum cost over parallel arcs in ``<dataset>_graph_arcs``.
    """
    arcs = dict()
    for dataset_name in dataset_list:
        # dataset_id is grouped on as well: selecting a non-grouped column is
        # rejected by standard SQL (e.g. MySQL ONLY_FULL_GROUP_BY), and grouping
        # on it never merges arcs of different datasets sharing node ids.
        results = execute_select("""
            SELECT dataset_id, dep_node, arr_node, min(cost)
            FROM %s_graph_arcs
            GROUP BY dataset_id, dep_node, arr_node
        """ % dataset_name)
        for item in results:
            dep_node = (int(item[0]), int(item[1]))
            arr_node = (int(item[0]), int(item[2]))
            arc_id = "%s-%s" % (dep_node, arr_node)
            cost = int(item[3])
            arcs[(dep_node, arr_node)] = {"arc_id": arc_id, "level": 0, "cost": cost}
    return arcs


def load_highway_graph(dataset_ids = None):
    """Load every arc of the ``global_graph`` table.

    Returns {(dep_node, arr_node): {"arc_id", "level", "cost"}} with nodes as
    (dataset_id, node_id) tuples and arc_id "<dep_node>-<arr_node>". Rows are
    requested ORDER BY level, so if a node pair repeats, the last row read
    (the highest level, assuming execute_select keeps row order) wins.

    NOTE(review): ``dataset_ids`` is accepted but unused; all rows of
    global_graph are returned regardless of its value.
    """
    rows = execute_select("""
        SELECT dep_dataset, dep_node, arr_dataset, arr_node, level, cost
        FROM global_graph ORDER BY level
    """)
    graph = {}
    for row in rows:
        src = (int(row[0]), int(row[1]))
        dst = (int(row[2]), int(row[3]))
        label = "%s-%s" % (src, dst)
        graph[(src, dst)] = {"arc_id": label, "level": int(row[4]), "cost": int(row[5])}
    return graph


def output_highway_graph(dataset_list, filename):
    """Export the base graph of dataset_list with the highway arcs highlighted.

    All nodes and arcs of the given datasets are loaded at level 0. Every
    global_graph arc whose two endpoints are among those nodes replaces the
    matching base arc. It also raises both endpoints' level. A node touched by
    several highway arcs takes the highest of their levels, so the result does
    not depend on dict iteration order.
    """
    log_info("Export highway graph to gexf ... " + str(dataset_list))
    nodes = load_nodes(dataset_list)
    arcs = load_arcs(dataset_list)
    highlight_arcs = load_highway_graph()
    for (dep_node, arr_node), harc in highlight_arcs.items():
        if dep_node not in nodes or arr_node not in nodes:
            continue
        level = harc["level"]
        for endpoint in (dep_node, arr_node):
            nodes[endpoint]["level"] = max(nodes[endpoint]["level"], level)
        arcs[(dep_node, arr_node)] = harc
    graph_to_gexf(nodes, arcs, filename)


def output_prefetched_routes(dataset_list, filename, path, route_set):
    """Export the base graph with a path and a set of routes highlighted.

    path       iterable of node keys (as produced by load_nodes); each node is
               set to level 4.
    route_set  iterable of (dataset_id, route) integer pairs. Every
               gtfs_graph_arcs arc on those routes is drawn at level 5 between
               copies of its endpoints keyed (dataset_id, node_id, 5). The copies
               keep the base nodes untouched.
    """
    log_info("Export prefetched routes to gexf ... " + str(dataset_list))
    nodes = load_nodes(dataset_list)
    arcs = load_arcs(dataset_list)
    for node in path:
        nodes[node]["level"] = 4

    conditions = ["(dataset_id=%d AND route=%d)" % route for route in route_set]
    if conditions:
        # dataset_id is grouped on too: route_set can span several datasets, and
        # grouping only by node ids would merge arcs from different datasets.
        results = execute_select("""
            SELECT dataset_id, dep_node, arr_node
            FROM gtfs_graph_arcs
            WHERE %s
            GROUP BY dataset_id, dep_node, arr_node
        """ % " OR ".join(conditions))
    else:
        # An empty route_set would leave a bare "WHERE" (a SQL syntax error);
        # there is simply nothing to highlight.
        results = []

    for item in results:
        dataset_id = int(item[0])
        dep_pnode = (dataset_id, int(item[1]))
        arr_pnode = (dataset_id, int(item[2]))
        dep_node, arr_node = dep_pnode + (5,), arr_pnode + (5,)
        for copy_key, base_key in ((dep_node, dep_pnode), (arr_node, arr_pnode)):
            base = nodes[base_key]
            nodes[copy_key] = {"latitude": base["latitude"],
                               "longitude": base["longitude"],
                               "name": base["name"],
                               "level": 5}
        arcs[(dep_node, arr_node)] = {"arc_id": "%s-%s" % (dep_node, arr_node),
                                      "level": 5, "cost": 1}

    graph_to_gexf(nodes, arcs, filename)


def output_martins_solution(dataset_list, filename, paths):
    ''' Export up to the first three Martins solution paths over the base graph.

        paths: list of paths, each a list of
               (dataset_id, node_id, sub_node_id(arc)) tuples.

        Path i (0-based) is drawn at level i + 1 on copies of its nodes keyed
        (dataset_id, node_id, i + 1). A hop between two entries with the same
        (dataset_id, arc) is labelled with that arc's route, looked up in
        gtfs_graph_arcs. Every other hop is labelled 'walk'. Hops that stay on
        the same node get no arc.
    '''
    log_info("Export martins solutions to gexf ... ")
    nodes = load_nodes(dataset_list)
    arcs = load_arcs(dataset_list)

    def overlay(pnode, lvl):
        # Copy a base node's position and name onto a per-path overlay node.
        base = nodes[pnode]
        return {"latitude": base["latitude"],
                "longitude": base["longitude"],
                "name": base["name"],
                "level": lvl}

    for idx, path in enumerate(paths[:3]):
        lvl = idx + 1
        for here, nxt in zip(path[:-1], path[1:]):
            src = (here[0], here[1], lvl)
            dst = (nxt[0], nxt[1], lvl)
            nodes[src] = overlay((here[0], here[1]), lvl)
            nodes[dst] = overlay((nxt[0], nxt[1]), lvl)
            if src == dst:
                continue
            shared_arc = (here[0], here[2])
            if shared_arc == (nxt[0], nxt[2]):
                row = execute_select("""
                        SELECT dataset_id, route FROM gtfs_graph_arcs
                        WHERE dataset_id=%d AND arc_id=%d LIMIT 1
                    """ % shared_arc)[0]
                route = (int(row[0]), int(row[1]))
            else:
                route = 'walk'
            arcs[(src, dst)] = {"arc_id": "%s@%s-%s" % (route, src, dst),
                                "level": lvl, "cost": 1}
    graph_to_gexf(nodes, arcs, filename)


def output_time_dependent_graph(graph, node_coors, filename):
    """Export a time-dependent graph to gexf via graph_to_gexf.

    graph       mapping dep -> {arr: val}; dep keys are tuples whose first two
                items index node_coors.
    node_coors  mapping (dep[0], dep[1]) -> coordinates; items [0] and [1] are
                divided by 1e5 to get latitude/longitude (presumably fixed-point
                degrees — confirm).
    filename    output base name passed to graph_to_gexf.

    Only keys of ``graph`` become nodes (level 0, name "Null"). Each arc's level
    and id depend on which marker key its ``val`` contains:
    "route_arc" -> level 1, id val["route_arc"];
    "transfer_arc" -> level 2, id "<dep> T <arr>";
    "walking_arc" -> level 3, id "<dep> W <arr>";
    otherwise -> level 0, id "<dep> -> <arr>".
    """
    nodes = {}
    arcs = {}
    for dep in graph:
        pnode = (dep[0], dep[1])
        # graph_to_gexf reads a "name" for every node, so one must be provided.
        nodes[dep] = {"latitude": node_coors[pnode][0]/1e+5,
                      "longitude": node_coors[pnode][1]/1e+5,
                      "level": 0,
                      "name": "Null"}
        for arr, val in graph[dep].items():
            if "route_arc" in val:
                arc_id = val["route_arc"]
                arcs[(dep, arr)] = {"arc_id": arc_id, "level": 1}
            elif "transfer_arc" in val:
                arc_id = str(dep) + " T " + str(arr)
                arcs[(dep, arr)] = {"arc_id": arc_id, "level": 2}
            elif "walking_arc" in val:
                arc_id = str(dep) + " W " + str(arr)
                arcs[(dep, arr)] = {"arc_id": arc_id, "level": 3}
            else:
                arc_id = str(dep) + " -> " + str(arr)
                arcs[(dep, arr)] = {"arc_id": arc_id, "level": 0}
    graph_to_gexf(nodes, arcs, filename)


# Script entry point: currently a no-op. The commented-out lines show an
# example export of the highway graph for a list of datasets.
if __name__ == "__main__":
    #dataset_list = ["nyctrs", "njtrs", "nynjpa", "mtrnth", "lirail"]
    #output_highway_graph(dataset_list, "graph")
    pass

